{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# [数据集定义与加载](https://www.paddlepaddle.org.cn/documentation/docs/zh/guides/beginner/data_load_cn.html)\n",
    "\n",
    "在飞桨框架中，可通过如下两个核心步骤完成数据集的定义与加载：\n",
    "\n",
    "1. 定义数据集：将磁盘中保存的原始图片、文字等样本和对应的标签映射到 Dataset，方便后续通过索引（index）读取数据，在 Dataset 中还可以进行一些数据变换、数据增广等预处理操作。在飞桨框架中推荐使用 paddle.io.Dataset 自定义数据集，另外在 paddle.vision.datasets 和 paddle.text 目录下飞桨内置了一些经典数据集方便直接调用。\n",
    "\n",
    "2. 迭代读取数据集：自动将数据集的样本进行分批（batch）、乱序（shuffle）等操作，方便训练时迭代读取，同时还支持多进程异步读取功能可加快数据读取速度。在飞桨框架中可使用 paddle.io.DataLoader 迭代读取数据集。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/lib/python3/dist-packages/urllib3/util/selectors.py:14: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated since Python 3.3, and in 3.10 it will stop working\n",
      "  from collections import namedtuple, Mapping\n",
      "/usr/lib/python3/dist-packages/urllib3/_collections.py:2: DeprecationWarning: Using or importing the ABCs from 'collections' instead of from 'collections.abc' is deprecated since Python 3.3, and in 3.10 it will stop working\n",
      "  from collections import Mapping, MutableMapping\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "计算机视觉（CV）相关数据集： ['DatasetFolder', 'ImageFolder', 'MNIST', 'FashionMNIST', 'Flowers', 'Cifar10', 'Cifar100', 'VOC2012']\n",
      "自然语言处理（NLP）相关数据集： ['Conll05st', 'Imdb', 'Imikolov', 'Movielens', 'UCIHousing', 'WMT14', 'WMT16']\n"
     ]
    }
   ],
   "source": [
    "import paddle\n",
    "\n",
    "print(\"计算机视觉（CV）相关数据集：\", paddle.vision.datasets.__all__)\n",
    "print(\"自然语言处理（NLP）相关数据集：\", paddle.text.__all__)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/local/lib/python3.8/dist-packages/ipykernel/ipkernel.py:283: DeprecationWarning: `should_run_async` will not call `transform_cell` automatically in the future. Please pass the result to `transformed_cell` argument and any exception that happen during thetransform in `preprocessing_exc_tuple` in IPython 7.17 and above.\n",
      "  and should_run_async(code)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train images:  60000 , test images:  10000\n"
     ]
    }
   ],
   "source": [
    "from paddle.vision.transforms import Normalize\n",
    "\n",
    "# 定义图像归一化处理方法，这里的CHW指图像格式需为 [C通道数，H图像高度，W图像宽度]\n",
    "transform = Normalize(mean=[127.5], std=[127.5], data_format=\"CHW\")\n",
    "# 下载数据集并初始化 DataSet\n",
    "train_dataset = paddle.vision.datasets.MNIST(mode=\"train\", transform=transform)\n",
    "test_dataset = paddle.vision.datasets.MNIST(mode=\"test\", transform=transform)\n",
    "print(\n",
    "    \"train images: \", len(train_dataset), \", test images: \", len(test_dataset)\n",
    ")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "shape of image:  (1, 28, 28)\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAEICAYAAACZA4KlAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8rg+JYAAAACXBIWXMAAAsTAAALEwEAmpwYAAAPW0lEQVR4nO3de4xc9XnG8eeJvZjYmMSOg+sQFzvglGsx6cqAsIAqCiUoEqAqECuKHErqNMFJaFwJSi+QilRulRARSpFMcTEV9wSEVdEk1IpwogaXhRowEG7GNDbGxmzBXH1Zv/1jx9Fidn67zJy5eN/vR1rtzHnPmfNq7GfPmfmdmZ8jQgDGvg90ugEA7UHYgSQIO5AEYQeSIOxAEoQdSIKwA0kQ9uRsh+03bX93lOt/p7Z+2B7f6v5QHXNRTW62Q9KciHh2n2VvSdr7n+O2iPjKkPosSc9L6omI3W1sF03gLzPqOX7oHwDs/ziNB5Ig7Khnte2XbN9VO23Hfo6wYzinSZol6UhJL0r6d96M2/8RdrxHRKyOiJ0R8aqkb0maLemoznaFZhF2jEZIcqebQHM4NcO72D5GUo+kxyR9UNKVkjZJerKTfaF5HNmxr+mSbpe0XdJ6Db52/1xE7OpkU2geF9UkZ/sdSTsk/TAi/mYU618u6duSJkiaFBEDLW4RFSHsQBKcxgNJEHYgiba+G3+AJ8SBmtTOXQKpvKM3tTN2DDtM2lTYbZ8p6WpJ4yT9S0QsLa1/oCbpRH+6mV0CKFgTq+rWGj6Ntz1O0rWSPivpaEkLbB/d6OMBaK1mXrPPk/RsRKyPiJ2SbpN0djVtAahaM2E/VNJvhtzfWFv2LrYX2e6z3bdLO5rYHYBmtPzd+IhYFhG9EdHbowmt3h2AOpoJ+yZJM4fc/3htGYAu1EzYH5Q0x/Zs2wdI+oKkldW0BaBqDQ+9RcRu24sl/VSDQ2/LI+LxyjoDUKmmxtkj4l5J91bUC4AW4nJZIAnCDiRB2IEkCDuQBGEHkiDsQBKEHUiCsANJEHYgCcIOJEHYgSQIO5AEYQeSIOxAEoQdSIKwA0kQdiAJwg4kQdiBJAg7kARhB5Ig7EAShB1IgrADSRB2IAnCDiRB2IEkCDuQBGEHkmhqFld0P48v/xOP++i0lu7/qb+YVbc2MHFPcdvDDt9arE/8uov1l646oG7t4d7bi9tuG3izWD/xziXF+hHffqBY74Smwm57g6TXJQ1I2h0RvVU0BaB6VRzZ/zAitlXwOABaiNfsQBLNhj0k/cz2Q7YXDbeC7UW2+2z37dKOJncHoFHNnsbPj4hNtg+RdJ/tX0fE6qErRMQyScsk6WBPjSb3B6BBTR3ZI2JT7fdWSXdLmldFUwCq13DYbU+yPXnvbUlnSFpXVWMAqtXMafx0SXfb3vs4t0TETyrpaowZd9ScYj0m9BTrL5724WL97ZPqjwlP/VB5vPgXx5fHmzvpP96aXKz/wz+dWayvOe6WurXnd71d3Hbpls8U6x/7xf73irThsEfEeknHV9gLgBZi6A1IgrADSRB2IAnCDiRB2IEk+IhrBQZO/1SxftWN1xbrn+yp/1HMsWxXDBTrf3vNl4v18W+Wh79OvnNx3drkTbuL207YVh6am9i3pljvRhzZgSQIO5AEYQeSIOxAEoQdSIKwA0kQdiAJxtkrMOGpF4v1h96ZWax/smdLle1Uasnmk4r19W+Uv4r6xsN/VLf22p7yOPn0H/5Xsd5K+98HWEfGkR1IgrADSRB2IAnCDiRB2IEkCDuQBGEHknBE+0YUD/bUONGfbtv+ukX/BScX69vPLH/d87hHDyrWH/n6Ne+7p72u3Pb7xfqDp5XH0Qdefa1Yj5PrfwHxhm8WN9XsBY+UV8B7rIlV2h79w85lzZEdSIKwA0kQdiAJwg4kQdiBJAg7kARhB5JgnL0LjJv2kWJ94JX+Yv35W+qPlT9+6vLitvP+/hvF+iHXdu4z5Xj/mhpnt73c9lbb64Ysm2r7PtvP1H5PqbJhANUbzWn8jZL2nfX+UkmrImKOpFW1+wC62Ihhj4jVkvY9jzxb0ora7RWSzqm2LQBVa/Q76KZHxOba7ZckTa+3ou1FkhZJ0oGa2ODuADSr6XfjY/Advrrv8kXEsojojYjeHk1odncAGtRo2LfYniFJtd9bq2sJQCs0GvaVkhbWbi+UdE817QBolRFfs9u+VdLpkqbZ3ijpcklLJd1h+0JJL0g6r5VNjnUD215pavtd2xuf3/2YLz5RrL983bjyA+wpz7GO7jFi2CNiQZ0SV8cA+xEulwWSIOxAEoQdSIKwA0kQdiAJpmweA4665Om6tQuOKw+a/Othq4r10z5/UbE++fYHinV0D47sQBKEHUiCsANJEHYgCcIOJEHYgSQIO5AE4+xjQGna5Fe+dlRx2/9d+XaxfumVNxXrf3neucV6/M+H6tZmfvdXxW3Vxq85z4AjO5AEYQeSIOxAEoQdSIKwA0kQdiAJwg4kwZTNyfX/ycnF+s2Xf69Ynz3+wIb3fcxNi4v1OddvLtZ3r9/Q8L7HqqambAYwNhB2IAnCDiRB2IEkCDuQBGEHkiDsQBKMs6MoTplbrB+8dGOxfusnftrwvo/8+VeK9d/7Tv3P8UvSwDPrG973/qqpcXbby21vtb1uyLIrbG+yvbb2c1aVDQOo3mhO42+UdOYwy38QEXNrP/dW2xaAqo0Y9ohYLam/Db0AaKFm3qBbbPvR2mn+lHor2V5ku8923y7taGJ3AJrRaNivk3S4pLmSNkv6fr0VI2JZRPRGRG+PJjS4OwDNaijsEbElIgYiYo+k6yXNq7YtAFVrKOy2Zwy5e66kdfXWBdAdRhxnt32rpNMlTZO0RdLltftzJYWkDZK+GhHlDx+LcfaxaNz0Q4r1F88/om5tzSVXF7f9wAjHoi8+f0ax/tr8V4r1sag0zj7iJBERsWCYxTc03RWAtuJyWSAJwg4kQdiBJAg7kARhB5LgI67omDs2lqdsnugDivW3Ymex/rlvXFz/se9eU9x2f8VXSQMg7EAWhB1IgrADSRB2IAnCDiRB2IEkRvzUG3LbM39usf7c58tTNh87d0Pd2kjj6CO5pv+EYn3iPX1NPf5Yw5EdSIKwA0kQdiAJwg4kQdiBJAg7kARhB5JgnH2Mc++xxfrT3yyPdV9/yopi/dQDy58pb8aO2FWsP9A/u/wAe0b8dvNUOLIDSRB2IAnCDiRB2IEkCDuQBGEHkiDsQBIjjrPbninpJknTNThF87KIuNr2VEm3S5qlwWmbz4uI/2tdq3mNn31Ysf7cBR+rW7vi/NuK2/7xQdsa6qkKl23pLdbvv/qkYn3KivL3zuPdRnNk3y1pSUQcLekkSRfZPlrSpZJWRcQcSatq9wF0qRHDHhGbI+Lh2u3XJT0p6VBJZ0vae3nVCknntKhHABV4X6/Zbc+SdIKkNZKmR8Te6xFf0uBpPoAuNeqw2z5I0o8lXRwR24fWYnDCuGEnjbO9yHaf7b5d2tFUswAaN6qw2+7RYNBvjoi7aou32J5Rq8+QtHW4bSNiWUT0RkRvjyZU0TOABowYdtuWdIOkJyPiqiGllZIW1m4vlHRP9e0BqMpoPuJ6iqQvSXrM9trassskLZV0h+0LJb0g6byWdDgGjJ/1u8X6a38wo1g//+9+Uqz/2YfvKtZbacnm8vDYr/65/vDa1Bv/u7jtlD0MrVVpxLBHxC8lDTvfsyQmWwf2E1xBByRB2IEkCDuQBGEHkiDsQBKEHUiCr5IepfEzfqdurX/5pOK2X5t9f7G+YPKWhnqqwuJN84v1h6+bW6xP+9G6Yn3q64yVdwuO7EAShB1IgrADSRB2IAnCDiRB2IEkCDuQRJpx9p1/VP7a4p1/3l+sX3bEvXVrZ3zwzYZ6qsqWgbfr1k5duaS47ZF//etifeqr5XHyPcUquglHdiAJwg4kQdiBJAg7kARhB5Ig7EAShB1IIs04+4Zzyn/Xnj7uzpbt+9pXDy/Wr77/jGLdA/W+yXvQkVc+X7c2Z8ua4rYDxSrGEo7sQBKEHUiCsANJEHYgCcIOJEHYgSQIO5CEI6K8gj1T0k2SpksKScsi4mrbV0j6U0kv11a9LCLqf+hb0sGeGieaWZ6BVlkTq7Q9+oe9MGM0F9XslrQkIh62PVnSQ7bvq9V+EBHfq6pRAK0zYtgjYrOkzbXbr9t+UtKhrW4MQLXe12t227MknSBp7zWYi20/anu57Sl1tllku8923y7taK5bAA0bddhtHyTpx5Iujojtkq6TdLikuRo88n9/uO0iYllE9EZEb48mNN8xgIaMKuy2ezQY9Jsj4i5JiogtETEQEXskXS9pXuvaBNCsEcNu25JukPRkRFw1ZPmMIaudK6k8nSeAjhrNu/GnSPqSpMdsr60tu0zSAttzNTgct0HSV1vQH4CKjObd+F9KGm7crjimDqC7cAUdkARhB5Ig7EAShB1IgrADSRB2IAnCDiRB2IEkCDuQBGEHkiDsQBKEHUiCsANJEHYgiRG/SrrSndkvS3phyKJpkra1rYH3p1t769a+JHprVJW9HRYRHx2u0Nawv2fndl9E9HasgYJu7a1b+5LorVHt6o3TeCAJwg4k0emwL+vw/ku6tbdu7Uuit0a1pbeOvmYH0D6dPrIDaBPCDiTRkbDbPtP2U7aftX1pJ3qox/YG24/ZXmu7r8O9LLe91fa6Icum2r7P9jO138POsdeh3q6wvan23K21fVaHeptp++e2n7D9uO1v1ZZ39Lkr9NWW563tr9ltj5P0tKTPSNoo6UFJCyLiibY2UoftDZJ6I6LjF2DYPlXSG5Juiohja8v+UVJ/RCyt/aGcEhGXdElvV0h6o9PTeNdmK5oxdJpxSedI+rI6+NwV+jpPbXjeOnFknyfp2YhYHxE7Jd0m6ewO9NH1ImK1pP59Fp8taUXt9goN/mdpuzq9dYWI2BwRD9duvy5p7zTjHX3uCn21RSfCfqik3wy5v1HdNd97SPqZ7YdsL+p0M8OYHhGba7dfkjS9k80MY8RpvNtpn2nGu+a5a2T682bxBt17zY+IT0n6rKSLaqerXSkGX4N109jpqKbxbpdhphn/rU4+d41Of96sToR9k6SZQ+5/vLasK0TEptrvrZLuVvdNRb1l7wy6td9bO9zPb3XTNN7DTTOuLnjuOjn9eSfC/qCkObZn2z5A0hckrexAH+9he1LtjRPZniTpDHXfVNQrJS2s3V4o6Z4O9vIu3TKNd71pxtXh567j059HRNt/JJ2lwXfkn5P0V53ooU5fn5D0SO3n8U73JulWDZ7W7dLgexsXSvqIpFWSnpH0n5KmdlFv/ybpMUmPajBYMzrU23wNnqI/Kmlt7eesTj93hb7a8rxxuSyQBG/QAUkQdiAJwg4kQdiBJAg7kARhB5Ig7EAS/w9pgMSoTFggTAAAAABJRU5ErkJggg==",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "%matplotlib inline\n",
    "from matplotlib import pyplot as plt\n",
    "\n",
    "for data in train_dataset:\n",
    "    image, label = data\n",
    "    print(\"shape of image: \", image.shape)\n",
    "    plt.title(str(label))\n",
    "    plt.imshow(image[0])\n",
    "    break\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1.2 使用 paddle.io.Dataset 自定义数据集\n",
    "在实际的场景中，一般需要使用自有的数据来定义数据集，这时可以通过 paddle.io.Dataset 基类来实现自定义数据集。\n",
    "\n",
    "可构建一个子类继承自 paddle.io.Dataset ，并且实现下面的三个函数：\n",
    "\n",
    "1. \\_\\_init\\_\\_：完成数据集初始化操作，将磁盘中的样本文件路径和对应标签映射到一个列表中。\n",
    "\n",
    "2. \\_\\_getitem\\_\\_：定义指定索引（index）时如何获取样本数据，最终返回对应 index 的单条数据（样本数据、对应的标签）。\n",
    "\n",
    "3. \\_\\_len\\_\\_：返回数据集的样本总数。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/local/lib/python3.8/dist-packages/ipykernel/ipkernel.py:283: DeprecationWarning: `should_run_async` will not call `transform_cell` automatically in the future. Please pass the result to `transformed_cell` argument and any exception that happen during thetransform in `preprocessing_exc_tuple` in IPython 7.17 and above.\n",
      "  and should_run_async(code)\n"
     ]
    }
   ],
   "source": [
    "# # 下载原始的 MNIST 数据集并解压\n",
    "# ! wget https://paddle-imagenet-models-name.bj.bcebos.com/data/mnist.tar\n",
    "# ! tar -xf mnist.tar\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train_custom_dataset images:  60000 test_custom_dataset images:  10000\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "import cv2\n",
    "import numpy as np\n",
    "from paddle.io import Dataset\n",
    "from paddle.vision.transforms import Normalize\n",
    "\n",
    "class MyDataset(Dataset):\n",
    "    def __init__(self, data_dir, label_path, transform=None):\n",
    "        super().__init__()\n",
    "        self.data_list = []\n",
    "        with open(label_path, encoding='utf-8') as f:\n",
    "            for line in f.readlines():\n",
    "                image_path, label = line.split(\"\\t\")\n",
    "                image_path = os.path.join(data_dir, image_path)\n",
    "                self.data_list.append([image_path, label])\n",
    "        self.transform = transform\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        image_path, label = self.data_list[idx]\n",
    "        image = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)\n",
    "        image = image.astype(\"float32\")\n",
    "        if self.transform is not None:\n",
    "            image = self.transform(image)\n",
    "        \n",
    "        label = int(label)\n",
    "        return image, label\n",
    "    \n",
    "    def __len__(self):\n",
    "        return len(self.data_list)\n",
    "\n",
    "# 定义图像归一化处理方法，这里的CHW指图像格式需为 [C通道数，H图像高度，W图像宽度]\n",
    "transform = Normalize(mean=[127.5], std=[127.5], data_format=\"CHW\")\n",
    "# 打印数据集样本数\n",
    "train_custom_dataset = MyDataset(\n",
    "    \"../datasets/mnist/train\", \"../datasets/mnist/train/label.txt\", transform\n",
    ")\n",
    "test_custom_dataset = MyDataset(\"../datasets/mnist/val\", \"../datasets/mnist/val/label.txt\", transform)\n",
    "print(\n",
    "    \"train_custom_dataset images: \",\n",
    "    len(train_custom_dataset),\n",
    "    \"test_custom_dataset images: \",\n",
    "    len(test_custom_dataset),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "shape of image:  (1, 28, 28)\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAEICAYAAACZA4KlAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8rg+JYAAAACXBIWXMAAAsTAAALEwEAmpwYAAAUp0lEQVR4nO3de3Cc1XkG8OfZlXyTLXxFKODEhjEY06QmCOM2hJIyIcC0GDpTJrRh3A6paIFMmOGPAm0nZEKntJOQMhQyKJhgWm6ZBopJDQE8JQwtOAjqYIO5GLDBQpYwvskXWdrdt39oTQXoe4/Yb3e/dc7zm9FY2ne/3eOVHn2rffecQzODiPzmy2U9ABGpD4VdJBIKu0gkFHaRSCjsIpFQ2EUiobCLREJhlzGRfIrkIMm95Y/Xsh6TpKOwi+dKM5ta/jgh68FIOgq7SCQUdvH8A8ntJP+b5JlZD0bSod4bL2MheRqAVwAMAfg6gH8BsNjM3sx0YFIxhV3GheRjAP7TzG7JeixSGT2Nl/EyAMx6EFI5hV0+geR0kl8jOYlkE8k/BXAGgMeyHptUrinrAUhDagZwA4CFAIoAXgVwgZm9numoJBX9zS4SCT2NF4mEwi4SCYVdJBIKu0gk6vpq/AROtEloqedd/r9ghzh0BeeFTL3GWRv0vycM1K1USj622f/Rt+GCW29Ug9iHITs45gOTKuwkzwFwM4A8gDvM7Ebv+pPQgtN4Vpo7rPzQfN6/QqheLCaWrJDxD0aKxwVpuzGh+05x+2ye4NcnNLv10r59ibWm2W3usYVtfW69Ua21NYm1ip/Gk8wDuBXAuQAWAbiY5KJKb09EaivN3+xLAGwys7fMbAjA/QCWVWdYIlJtacJ+NIB3R329tXzZR5DsJNlNsnsYB1PcnYikUfNX482sy8w6zKyjGRNrfXcikiBN2HsAzB319THly0SkAaUJ+/MAFpCcT3ICRhY4WFWdYYlItVXcejOzAskrAfwCI623O83s5VSjyQXaX55ScmsMGEd7LOv2WVbStO0QbmmmaUtaYdivDw9VfNuh1lpu2jT/vof8+7aDjff6VKo+u5mtBrC6SmMRkRrS22VFIqGwi0RCYReJhMIuEgmFXSQSCrtIJOq/umyaaaq55GMteepy7QXeH5ALTMUMLvrpTK8FAPPqNV5QNNRHz02ZklgrDfq96NDjxgmBKbBTJifWQn320t69br3Wj2st6MwuEgmFXSQSCrtIJBR2kUgo7CKRUNhFIlHf1hv9KZGhNo7bXkszPRYITpH1WoahaZ6lwHTH4DTRUqDNk6YNFHjcvHYnEGj7ASjt3/+ph/ThsYOB78ngoF/fs6fi+w49pvkZM9x6cefOyu+7RnRmF4mEwi4SCYVdJBIKu0gkFHaRSCjsIpFQ2EUiUecprvR3S63hcs7BLXqHAnNknb5rmiWNgXH00WspMDfYCunGxonJuwCFlltuaj/KrRd6t7n1fNuRibXSjl3uscHvaaZzqiujM7tIJBR2kUgo7CKRUNhFIqGwi0RCYReJhMIuEom69tkJgM688FQd3VC/eDjQw89yaeDAXHqvVw0AZPKSy5ycvJwyALB1qltHLnA+GPK3VX6zc17yfS8acI+dOW2fW59+efJtA8Br30uec77pK4+7x4Ycf/dfufX51zyb6vZrIVXYSW4GMACgCKBgZh3VGJSIVF81zuxfMbPtVbgdEakh/c0uEom0YTcAj5N8gWTnWFcg2Umym2T3EPz3QotI7aR9Gn+6mfWQPBLAEyRfNbOnR1/BzLoAdAHAEblZh98GWSK/IVKd2c2sp/xvP4CHACypxqBEpPoqDjvJFpLTDn0O4GwAG6o1MBGprjRP49sAPFTumzcBuNfMHvMOMFhwbfiKMfB7K7QufPD2k98fkJ892z92RqtbPnDsTLe+c4G/dfGeBcn/t/mLet1jL//sU2793Cl+o2VKzt822bO35K/73rVrkVt/5PbPu/VNJ/0ksdZb8Ldk/vY7y9z6Uc+m/HnKQMVhN7O3APx2FcciIjWk1ptIJBR2kUgo7CKRUNhFIqGwi0SivktJW3hbZleKbZMtZestP316Yq10TPKSxQDw9w8mt4AA4Ngm/zGZkZ/i1ovO9N58qCUZ5LfWdpcOuPU8kr9nucC5ZsU957j1ydv9N2QuPfCXibWWXn9qbu6g//My+ZlfufVGpDO7SCQUdpFIKOwikVDYRSKhsItEQmEXiYTCLhKJOm/ZnFINt01Gzu/Tl/bvT6zx9c3usc8dOM6tnzL9Xbe+v+T/394uJPeET5rgLyUdmur5b3v8iY1r+he69QeO//fE2qD5ve65N73g1m0otK1y8s9Lbor/3gXv+w0AbPbff5D657EGdGYXiYTCLhIJhV0kEgq7SCQUdpFIKOwikVDYRSJxePXZs1Ry5k4X/bnPP7npD9z6Tb/nL6k86RW/V/7yt25z657vv3+Gf9tL/R8RG+5x6xctGXNXMADA69+c5B57QvF/3Xpom22vl86mdD/6uRb/e1LcpT67iGREYReJhMIuEgmFXSQSCrtIJBR2kUgo7CKRqH+f3Vn7PdQ3rSln7XUAsBTLzs95wN+2vu3RaW690POeW1948iWJtftPvcM99sm7l7r19pw/pzzoV+sTSyfumO8eWgztMZBmDYKJE/3bDiju2p3q+CwEz+wk7yTZT3LDqMtmknyC5Bvlf2fUdpgiktZ4nsbfBeDjW3NcA2CNmS0AsKb8tYg0sGDYzexpADs+dvEyACvLn68EcEF1hyUi1Vbp3+xtZtZb/nwbgLakK5LsBNAJAJPgr/slIrWT+tV4MzMAia+smVmXmXWYWUcz0r0oIiKVqzTsfSTbAaD8b3/1hiQitVBp2FcBWF7+fDmAh6szHBGpleDf7CTvA3AmgNkktwL4DoAbAfyU5KUAtgC4aNz36O4X7ve6a9qHD92202i3UI9+wJ/bXBoYcOv51la3fnBn8rzwxYF+8twL33br9h/+3vOFLf6a9/kZyV3Z4ib/vt33ZADItQTWfnceVzt40D3WGzcAFHfudOuNKBh2M7s4oXRWlcciIjWkt8uKREJhF4mEwi4SCYVdJBIKu0gk6j/FNdCmOiyFljRuaXHrpX37Ut398Zc9n1j7oxO/6h778+Mfdeu/f8I33XpzoPWWpkXVdPRn3Hpo6q+3XLQFps+Gxh2aIhtq7WVBZ3aRSCjsIpFQ2EUiobCLREJhF4mEwi4SCYVdJBIZ9NkzXC46I6E+eqhnW9yzx63npx+RWNv93c+6x/bf5Y/tqlvvdevffdXfjvrA2tmJtXm3bnSPLWz1t4MOyU1LXqI79RRVbwvvBqUzu0gkFHaRSCjsIpFQ2EUiobCLREJhF4mEwi4SCVod+96tnGmnMb5Fab151QCAvL/1cC5lH94z+IdL3Povb+9y6zuLydsiA8CMfPJyz4tuu9w9dv69/nz1wttb3Hqq93QEtoNGKcUe3jW01tZgj+0Ycw1undlFIqGwi0RCYReJhMIuEgmFXSQSCrtIJBR2kUioz94AQn14C82ddnq+oTXrkfN/3xe/cJxbb7qh362vPmF1Yq2/6M+l/51fXunWF/7tB269sPmdxFrwMQ+sK5+blLxNNgCUBgfdeq2k6rOTvJNkP8kNoy67nmQPyXXlj/OqOWARqb7xPI2/C8A5Y1z+QzNbXP5I/vUtIg0hGHYzexrAjjqMRURqKM0LdFeSfKn8NH9G0pVIdpLsJtk9jMbb/0okFpWG/UcAjgOwGEAvgB8kXdHMusysw8w6muFP6BCR2qko7GbWZ2ZFMysB+DEAf+qUiGSuorCTbB/15YUANiRdV0QaQ7DPTvI+AGcCmA2gD8B3yl8vBmAANgO4zMx6Q3d2RG6WLZ2U3KUrDQ37N+D0k9P2TUPYPCH5toeHanbb45H2/l2Bed351qlu/b1LTkqsrbv2toqGdMhZr5zv1pvP3ZZYCz1mh+P+64DfZw9uEmFmF49x8YrUoxKRutLbZUUiobCLREJhF4mEwi4SCYVdJBJ13bLZzGo29a/WrTevVZN2GmlpYKCSIX3I27LZAu1MG0rXgirtO+DW2275n+Tite6hKFrJrf/d/J+79Wv/uDOx1nrfWvfYUGstP2umWy9+0HjTSXRmF4mEwi4SCYVdJBIKu0gkFHaRSCjsIpFQ2EUiUdc+O0l3CV5O8Kd6lg4k9+jT9u/zc+b4V3D6rqEtk0NTWEN9+tI+f8nl4q7dzo2n23rYvniiW3/jG8lbMgPA6aduTKxtHPK3ez62udmt7yq1uvXWe59LrIXeP5Cb7C8V3Yh99BCd2UUiobCLREJhF4mEwi4SCYVdJBIKu0gkFHaRSDTWfPZAr9zrV4f6pij6/eTi++/7xzvS9slDyxrnZyTurjVy+/uT+9W5eXPdYzde7d/2w2ff4tan5/x1AtrzkxNruwNbUU+k32e/o+fLbn1ks6KxhearFwP1pvaj3HqhN3kZ66zozC4SCYVdJBIKu0gkFHaRSCjsIpFQ2EUiobCLRCLYZyc5F8DdANowskVzl5ndTHImgAcAzMPIts0XmdnO2g0VYD75d1NoPnuoD++tvQ4Axe0fJN+30+cGxtEn3+v34THB7ze/c8UpibUbl9/lHnt+iz/23sB6++1N/pbNnhW7Frv1Vd87y623PvqyW6/lNtuN2EcPGc+ZvQDgajNbBGApgCtILgJwDYA1ZrYAwJry1yLSoIJhN7NeM3ux/PkAgI0AjgawDMDK8tVWArigRmMUkSr4VH+zk5wH4GQAawG0mdmh9yNuw8jTfBFpUOMOO8mpAH4G4Coz+8iia2ZmGPl7fqzjOkl2k+wehv9+YxGpnXGFnWQzRoJ+j5k9WL64j2R7ud4OoH+sY82sy8w6zKyjGYHJKiJSM8GwkySAFQA2mtlNo0qrACwvf74cwMPVH56IVMt4prh+CcAlANaTXFe+7DoANwL4KclLAWwBcFHohpjPI9+a3OLylooGENz62BOc0hjYutjT1Hakf9tO2w4Atv/5qW79dy/rdusPt/vTUH2BpaYDvrH5TLf+5i0LE2sznnzTPXbq+8lLQQOAv6Gzz1vSHABygVZsYVtfinvPRjDsZvYMACaU/UaoiDQMvYNOJBIKu0gkFHaRSCjsIpFQ2EUiobCLRKKuS0mjVIINJve7Q71wb+Hh0DTSkOJOf3Zu07HzEmtv/eM099jbT/mFW5/X9JRbn8KkzueIZvpLWXsufed0t/7rOz7v1o98xO+Vt/Yl98r9xb2Bps/5y2AX3/N73d6U6JBQHz20DXfaKbS1oDO7SCQUdpFIKOwikVDYRSKhsItEQmEXiYTCLhKJhtqyObjtsrPFb6hPfmDZErfe/C1/aeDzP/NCYu2K6e+6x+4s+ss1T8n5/+/Q1sXPDSZ3rP/kkSvcY0/85+RtjQFg1tvPuvVQr9yTmzLFrRe2+I9r2q2yPfk5c9x6mi2+s6Izu0gkFHaRSCjsIpFQ2EUiobCLREJhF4mEwi4SifrOZweAXPI65aH57J6mY4526+992f+9tnrBfW79+Obknm5/0e/n3tB3plt/ZP0X3Hpzrz93esHtWxNrC3dvdI8t7Nrt1kPY5P8ImbPlc2ir65A0ffSQw7GPHqIzu0gkFHaRSCjsIpFQ2EUiobCLREJhF4mEwi4SCZp5q7EDJOcCuBtAG0aWbu8ys5tJXg/gLwAcakheZ2arvdtqzc20pU1fS6znpvrzk21oOLGWtmebm+av/Y5h576dOfoAkJ81060XP9jh1kPvIShs7UmspV7fPLBmPQI/P1Jfa20N9tiOMb9p43lTTQHA1Wb2IslpAF4g+US59kMz+361BioitRMMu5n1Augtfz5AciMA/1QjIg3nU/3NTnIegJMBrC1fdCXJl0jeSXLM/ZdIdpLsJtk9bJW/HVZE0hl32ElOBfAzAFeZ2R4APwJwHIDFGDnz/2Cs48ysy8w6zKyjmYE15kSkZsYVdpLNGAn6PWb2IACYWZ+ZFc2sBODHAPwVHUUkU8GwkySAFQA2mtlNoy5vH3W1CwFsqP7wRKRaxvNq/JcAXAJgPcl15cuuA3AxycUYacdtBnBZ8JbMn/JYTDHdMthiKvqLHpcGBiq+79ykSW491FoLKbznL3PtsUJyyxAIL9+dZtqxNJbxvBr/DICx+nZuT11EGoveQScSCYVdJBIKu0gkFHaRSCjsIpFQ2EUiUdelpNnchKbZbYn1wra+im87OFWzhtJOcfW2ogaA4u49bt1bztl7X8N47jtEffrDh87sIpFQ2EUiobCLREJhF4mEwi4SCYVdJBIKu0gkgktJV/XOyPcBbBl10WwA2+s2gE+nUcfWqOMCNLZKVXNsnzOzOWMV6hr2T9w52W1mHZkNwNGoY2vUcQEaW6XqNTY9jReJhMIuEomsw96V8f17GnVsjTouQGOrVF3Glunf7CJSP1mf2UWkThR2kUhkEnaS55B8jeQmktdkMYYkJDeTXE9yHcnujMdyJ8l+khtGXTaT5BMk3yj/O+YeexmN7XqSPeXHbh3J8zIa21yS/0XyFZIvk/x2+fJMHztnXHV53Or+NzvJPIDXAXwVwFYAzwO42MxeqetAEpDcDKDDzDJ/AwbJMwDsBXC3mf1W+bJ/ArDDzG4s/6KcYWZ/3SBjux7A3qy38S7vVtQ+eptxABcA+DNk+Ng547oIdXjcsjizLwGwyczeMrMhAPcDWJbBOBqemT0N4OPbySwDsLL8+UqM/LDUXcLYGoKZ9ZrZi+XPBwAc2mY808fOGVddZBH2owG8O+rrrWis/d4NwOMkXyDZmfVgxtBmZr3lz7cBSF7nKxvBbbzr6WPbjDfMY1fJ9udp6QW6TzrdzL4I4FwAV5SfrjYkG/kbrJF6p+Paxrtexthm/ENZPnaVbn+eVhZh7wEwd9TXx5Qvawhm1lP+tx/AQ2i8raj7Du2gW/63P+PxfKiRtvEea5txNMBjl+X251mE/XkAC0jOJzkBwNcBrMpgHJ9AsqX8wglItgA4G423FfUqAMvLny8H8HCGY/mIRtnGO2mbcWT82GW+/bmZ1f0DwHkYeUX+TQB/k8UYEsZ1LIBflz9eznpsAO7DyNO6YYy8tnEpgFkA1gB4A8CTAGY20Nj+FcB6AC9hJFjtGY3tdIw8RX8JwLryx3lZP3bOuOryuOntsiKR0At0IpFQ2EUiobCLREJhF4mEwi4SCYVdJBIKu0gk/g81xM9ks5Ld8AAAAABJRU5ErkJggg==",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "for data in train_custom_dataset:\n",
    "    image, label = data\n",
    "    print(\"shape of image: \", image.shape)\n",
    "    plt.title(str(label))\n",
    "    plt.imshow(image[0])\n",
    "    break\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/local/lib/python3.8/dist-packages/ipykernel/ipkernel.py:283: DeprecationWarning: `should_run_async` will not call `transform_cell` automatically in the future. Please pass the result to `transformed_cell` argument and any exception that happen during thetransform in `preprocessing_exc_tuple` in IPython 7.17 and above.\n",
      "  and should_run_async(code)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "batch_id: 0, 训练数据shape: [64, 1, 28, 28], 标签数据shape: [64]\n"
     ]
    }
   ],
   "source": [
    "train_loader = paddle.io.DataLoader(\n",
    "    train_custom_dataset,\n",
    "    batch_size=64,\n",
    "    shuffle=True,\n",
    "    num_workers=1,\n",
    "    drop_last=True,\n",
    ")\n",
    "# 调用 DataLoader 迭代读取数据\n",
    "for batch_id, data in enumerate(train_loader()):\n",
    "    images, labels = data\n",
    "    print(\n",
    "        \"batch_id: {}, 训练数据shape: {}, 标签数据shape: {}\".format(\n",
    "            batch_id, images.shape, labels.shape\n",
    "        )\n",
    "    )\n",
    "    break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "BatchSampler 每轮迭代返回一个索引列表\n",
      "[52504, 52484, 40159, 59830, 52638, 24428, 15767, 32915]\n",
      "在 DataLoader 中使用 BatchSampler，返回索引对应的一组样本和标签数据 \n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/local/lib/python3.8/dist-packages/ipykernel/ipkernel.py:283: DeprecationWarning: `should_run_async` will not call `transform_cell` automatically in the future. Please pass the result to `transformed_cell` argument and any exception that happen during thetransform in `preprocessing_exc_tuple` in IPython 7.17 and above.\n",
      "  and should_run_async(code)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "batch_id: 0, 训练数据shape: [8, 1, 28, 28], 标签数据shape: [8]\n"
     ]
    }
   ],
   "source": [
    "from paddle.io import BatchSampler\n",
    "\n",
    "bs = BatchSampler(\n",
    "    train_custom_dataset, batch_size=8, shuffle=True, drop_last=True\n",
    ")\n",
    "\n",
    "print(\"BatchSampler 每轮迭代返回一个索引列表\")\n",
    "for batch_indices in bs:\n",
    "    print(batch_indices)\n",
    "    break\n",
    "\n",
    "# 在 DataLoader 中使用 BatchSampler 获取采样数据\n",
    "train_loader = paddle.io.DataLoader(\n",
    "    train_custom_dataset, batch_sampler=bs, num_workers=1\n",
    ")\n",
    "\n",
    "print(\"在 DataLoader 中使用 BatchSampler，返回索引对应的一组样本和标签数据 \")\n",
    "for batch_id, data in enumerate(train_loader()):\n",
    "    images, labels = data\n",
    "    print(\n",
    "        \"batch_id: {}, 训练数据shape: {}, 标签数据shape: {}\".format(\n",
    "            batch_id, images.shape, labels.shape\n",
    "        )\n",
    "    )\n",
    "    break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "-----------------顺序采样----------------\n",
      "[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]\n",
      "[10, 11, 12, 13, 14, 15, 16, 17, 18, 19]\n",
      "[20, 21, 22, 23, 24, 25, 26, 27, 28, 29]\n",
      "[30, 31, 32, 33, 34, 35, 36, 37, 38, 39]\n",
      "[40, 41, 42, 43, 44, 45, 46, 47, 48, 49]\n",
      "[50, 51, 52, 53, 54, 55, 56, 57, 58, 59]\n",
      "[60, 61, 62, 63, 64, 65, 66, 67, 68, 69]\n",
      "[70, 71, 72, 73, 74, 75, 76, 77, 78, 79]\n",
      "[80, 81, 82, 83, 84, 85, 86, 87, 88, 89]\n",
      "[90, 91, 92, 93, 94, 95, 96, 97, 98, 99]\n",
      "-----------------随机采样----------------\n",
      "[74, 33, 98, 24, 0, 38, 65, 55, 83, 27]\n",
      "[61, 20, 53, 11, 26, 52, 89, 97, 37, 62]\n",
      "[41, 87, 12, 77, 71, 78, 44, 80, 39, 40]\n",
      "[49, 5, 86, 30, 66, 35, 7, 64, 92, 82]\n",
      "[81, 54, 1, 94, 96, 63, 28, 31, 56, 34]\n",
      "[68, 69, 13, 59, 88, 45, 32, 43, 14, 25]\n",
      "[3, 48, 70, 47, 79, 10, 57, 73, 93, 8]\n",
      "[90, 72, 22, 16, 91, 67, 17, 84, 4, 15]\n",
      "[85, 51, 50, 95, 6, 75, 58, 36, 42, 2]\n",
      "[99, 19, 29, 21, 18, 76, 60, 23, 9, 46]\n",
      "-----------------分布式采样----------------\n",
      "[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]\n",
      "[20, 21, 22, 23, 24, 25, 26, 27, 28, 29]\n",
      "[40, 41, 42, 43, 44, 45, 46, 47, 48, 49]\n",
      "[60, 61, 62, 63, 64, 65, 66, 67, 68, 69]\n",
      "[80, 81, 82, 83, 84, 85, 86, 87, 88, 89]\n"
     ]
    }
   ],
   "source": [
    "from paddle.io import (\n",
    "    SequenceSampler,\n",
    "    RandomSampler,\n",
    "    BatchSampler,\n",
    "    DistributedBatchSampler,\n",
    ")\n",
    "\n",
    "\n",
    "class RandomDataset(paddle.io.Dataset):\n",
    "    def __init__(self, num_samples):\n",
    "        self.num_samples = num_samples\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        image = np.random.random([784]).astype(\"float32\")\n",
    "        label = np.random.randint(0, 9, (1,)).astype(\"int64\")\n",
    "        return image, label\n",
    "\n",
    "    def __len__(self):\n",
    "        return self.num_samples\n",
    "\n",
    "\n",
    "train_dataset = RandomDataset(100)\n",
    "\n",
    "print(\"-----------------顺序采样----------------\")\n",
    "sampler = SequenceSampler(train_dataset)\n",
    "batch_sampler = BatchSampler(sampler=sampler, batch_size=10)\n",
    "\n",
    "for index in batch_sampler:\n",
    "    print(index)\n",
    "\n",
    "print(\"-----------------随机采样----------------\")\n",
    "sampler = RandomSampler(train_dataset)\n",
    "batch_sampler = BatchSampler(sampler=sampler, batch_size=10)\n",
    "\n",
    "for index in batch_sampler:\n",
    "    print(index)\n",
    "\n",
    "print(\"-----------------分布式采样----------------\")\n",
    "batch_sampler = DistributedBatchSampler(\n",
    "    train_dataset, num_replicas=2, batch_size=10\n",
    ")\n",
    "\n",
    "for index in batch_sampler:\n",
    "    print(index)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "![](https://www.paddlepaddle.org.cn/documentation/docs/zh/_images/data_pipeline.png)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
